{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/home/leon/miniconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from keras.models import Model\n",
    "from keras.layers import Input\n",
    "from keras.layers.core import Dense\n",
    "from keras.layers.convolutional import Conv2D\n",
    "from keras.layers.wrappers import TimeDistributed\n",
    "from keras import backend as K\n",
    "import json\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_decimal(arr, places=6):\n",
    "    \"\"\"Round every number in `arr` to `places` decimal places.\n",
    "\n",
    "    Each value is multiplied by 10**places, rounded with the built-in\n",
    "    `round`, and divided by the same factor again. The rounded values are\n",
    "    returned as a new list in the original order.\n",
    "    \"\"\"\n",
    "    scale = 10**places\n",
    "    rounded = []\n",
    "    for value in arr:\n",
    "        rounded.append(round(value * scale) / scale)\n",
    "    return rounded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixtures accumulate here, keyed by test name; the OrderedDict keeps them in\n",
    "# insertion order when the collection is serialized with json further down.\n",
    "DATA = OrderedDict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### TimeDistributed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**[wrappers.TimeDistributed.0] wrap a Dense layer with 4 units (input: 3x6)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixture: TimeDistributed(Dense(4)) evaluated on a single (3, 6) input sample.\n",
    "random_seed = 2000\n",
    "data_in_shape = (3, 6)\n",
    "\n",
    "# single-layer model: Input -> TimeDistributed(Dense)\n",
    "layer_0 = Input(shape=data_in_shape)\n",
    "layer_1 = TimeDistributed(Dense(4))(layer_0)\n",
    "model = Model(inputs=layer_0, outputs=layer_1)\n",
    "\n",
    "# seeded input values, rescaled from [0, 1) to [-1, 1)\n",
    "np.random.seed(random_seed)\n",
    "data_in = 2 * np.random.random(data_in_shape) - 1\n",
    "\n",
    "# overwrite every weight array with seeded values in [-1, 1);\n",
    "# the i-th array is drawn after reseeding with random_seed + i\n",
    "# NOTE(review): for i == 0 this is the same seed used for data_in above, so the\n",
    "# first weight array starts with the same values as the input (see the printed\n",
    "# JSON at the end of the notebook) -- confirm this overlap is intended\n",
    "weights = []\n",
    "for i, w in enumerate(model.get_weights()):\n",
    "    np.random.seed(random_seed + i)\n",
    "    weights.append(2 * np.random.random(w.shape) - 1)\n",
    "model.set_weights(weights)\n",
    "\n",
    "# predict on a batch containing only data_in; result[0] is that sample's output\n",
    "result = model.predict(np.array([data_in]))\n",
    "data_out_shape = result[0].shape\n",
    "# flatten and round (format_decimal default: 6 places) before storing\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(result[0].ravel().tolist())\n",
    "\n",
    "# record input, weights and expected output (flattened data plus shape) under the test key\n",
    "DATA['wrappers.TimeDistributed.0'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': [{'data': format_decimal(w.ravel().tolist()), 'shape': w.shape} for w in weights],\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**[wrappers.TimeDistributed.1] wrap a Conv2D layer with 6 3x3 filters (input: 5x4x4x2)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixture: TimeDistributed(Conv2D(6, (3,3), channels_last)) evaluated on a single\n",
    "# (5, 4, 4, 2) input sample.\n",
    "random_seed = 2000\n",
    "data_in_shape = (5, 4, 4, 2)\n",
    "\n",
    "# single-layer model: Input -> TimeDistributed(Conv2D)\n",
    "layer_0 = Input(shape=data_in_shape)\n",
    "layer_1 = TimeDistributed(Conv2D(6, (3,3), data_format='channels_last'))(layer_0)\n",
    "model = Model(inputs=layer_0, outputs=layer_1)\n",
    "\n",
    "# seeded input values, rescaled from [0, 1) to [-1, 1)\n",
    "np.random.seed(random_seed)\n",
    "data_in = 2 * np.random.random(data_in_shape) - 1\n",
    "\n",
    "# overwrite every weight array with seeded values in [-1, 1);\n",
    "# the i-th array is drawn after reseeding with random_seed + i\n",
    "# NOTE(review): for i == 0 this is the same seed used for data_in above, so the\n",
    "# first weight array starts with the same values as the input (see the printed\n",
    "# JSON at the end of the notebook) -- confirm this overlap is intended\n",
    "weights = []\n",
    "for i, w in enumerate(model.get_weights()):\n",
    "    np.random.seed(random_seed + i)\n",
    "    weights.append(2 * np.random.random(w.shape) - 1)\n",
    "model.set_weights(weights)\n",
    "\n",
    "# predict on a batch containing only data_in; result[0] is that sample's output\n",
    "result = model.predict(np.array([data_in]))\n",
    "data_out_shape = result[0].shape\n",
    "# flatten and round (format_decimal default: 6 places) before storing\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(result[0].ravel().tolist())\n",
    "\n",
    "# record input, weights and expected output (flattened data plus shape) under the test key\n",
    "DATA['wrappers.TimeDistributed.1'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': [{'data': format_decimal(w.ravel().tolist()), 'shape': w.shape} for w in weights],\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### export for Keras.js tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Destination of the fixture file; the relative path resolves against the\n",
    "# current working directory of the kernel.\n",
    "filename = '../../../test/data/layers/wrappers/TimeDistributed.json'\n",
    "\n",
    "# exist_ok=True creates the directory tree only when it is missing, without a\n",
    "# separate existence check that could race with another process creating it.\n",
    "os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
    "\n",
    "# Write all collected fixtures as a single JSON document.\n",
    "with open(filename, 'w') as f:\n",
    "    json.dump(DATA, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"wrappers.TimeDistributed.0\": {\"input\": {\"data\": [0.141035, 0.129058, -0.023116, -0.327044, -0.248264, 0.064072, -0.863787, 0.169058, -0.524204, -0.678487, -0.695762, -0.745862, -0.345118, 0.388308, -0.282067, 0.782731, -0.59624, -0.778795], \"shape\": [3, 6]}, \"weights\": [{\"data\": [0.141035, 0.129058, -0.023116, -0.327044, -0.248264, 0.064072, -0.863787, 0.169058, -0.524204, -0.678487, -0.695762, -0.745862, -0.345118, 0.388308, -0.282067, 0.782731, -0.59624, -0.778795, 0.055114, 0.735311, -0.476251, -0.00121, -0.142871, 0.060008], \"shape\": [6, 4]}, {\"data\": [-0.665749, -0.838004, 0.920451, 0.8769], \"shape\": [4]}], \"expected\": {\"data\": [-0.435401, -0.729574, 0.891208, 0.435142, 0.449463, -0.303688, 1.418705, 0.491531, -0.206694, 0.102946, 0.646889, 1.393312], \"shape\": [3, 4]}}, \"wrappers.TimeDistributed.1\": {\"input\": {\"data\": [0.141035, 0.129058, -0.023116, -0.327044, -0.248264, 0.064072, -0.863787, 0.169058, -0.524204, -0.678487, -0.695762, -0.745862, -0.345118, 0.388308, -0.282067, 0.782731, -0.59624, -0.778795, 0.055114, 0.735311, -0.476251, -0.00121, -0.142871, 0.060008, 0.147894, -0.216289, -0.840972, 0.734562, -0.670993, 0.606963, -0.424144, -0.462858, 0.434956, 0.762811, 0.98424, -0.0833, 0.570259, 0.477388, -0.052834, -0.030331, 0.86601, 0.505308, -0.681422, -0.730379, -0.178646, 0.513073, -0.574974, -0.371941, -0.597453, 0.87685, 0.00883, 0.207463, 0.675097, 0.220365, 0.471146, -0.180468, -0.02072, 0.017849, 0.012965, 0.236682, 0.66921, 0.173131, -0.957385, 0.471247, 0.841267, 0.511354, -0.430488, 0.899198, 0.679766, 0.6299, 0.487356, 0.829739, 0.792468, -0.759192, -0.25087, -0.47263, -0.357234, 0.453933, 0.475895, -0.071184, 0.528075, -0.516599, 0.807998, 0.143212, -0.373437, -0.952796, 0.040104, 0.160803, -0.743346, 0.534637, 0.92295, -0.646198, 0.692993, 0.776042, -0.474673, 0.98543, 0.184755, -0.30293, -0.022748, 0.135056, 0.643602, 0.502182, 0.218371, -0.033313, 0.643199, 0.824619, -0.750012, 0.913739, 0.493227, 
-0.223239, 0.962789, -0.051118, -0.649982, 0.028419, 0.749459, 0.525479, -0.045743, -0.979086, -0.669124, -0.35604, -0.444705, 0.763697, 0.780956, 0.99184, 0.253779, -0.425107, 0.913943, 0.032233, 0.04514, 0.302513, 0.234363, -0.133738, 0.062949, -0.420645, 0.063063, 0.515796, -0.982415, 0.2082, 0.685465, 0.948766, 0.572355, 0.526904, -0.070458, 0.363551, 0.61422, 0.201742, 0.831489, 0.011985, -0.426671, 0.597698, -0.259001, 0.138813, 0.482189, -0.30467, 0.396641, 0.487531, 0.455809, 0.803298, 0.753098, 0.123971], \"shape\": [5, 4, 4, 2]}, \"weights\": [{\"data\": [0.141035, 0.129058, -0.023116, -0.327044, -0.248264, 0.064072, -0.863787, 0.169058, -0.524204, -0.678487, -0.695762, -0.745862, -0.345118, 0.388308, -0.282067, 0.782731, -0.59624, -0.778795, 0.055114, 0.735311, -0.476251, -0.00121, -0.142871, 0.060008, 0.147894, -0.216289, -0.840972, 0.734562, -0.670993, 0.606963, -0.424144, -0.462858, 0.434956, 0.762811, 0.98424, -0.0833, 0.570259, 0.477388, -0.052834, -0.030331, 0.86601, 0.505308, -0.681422, -0.730379, -0.178646, 0.513073, -0.574974, -0.371941, -0.597453, 0.87685, 0.00883, 0.207463, 0.675097, 0.220365, 0.471146, -0.180468, -0.02072, 0.017849, 0.012965, 0.236682, 0.66921, 0.173131, -0.957385, 0.471247, 0.841267, 0.511354, -0.430488, 0.899198, 0.679766, 0.6299, 0.487356, 0.829739, 0.792468, -0.759192, -0.25087, -0.47263, -0.357234, 0.453933, 0.475895, -0.071184, 0.528075, -0.516599, 0.807998, 0.143212, -0.373437, -0.952796, 0.040104, 0.160803, -0.743346, 0.534637, 0.92295, -0.646198, 0.692993, 0.776042, -0.474673, 0.98543, 0.184755, -0.30293, -0.022748, 0.135056, 0.643602, 0.502182, 0.218371, -0.033313, 0.643199, 0.824619, -0.750012, 0.913739], \"shape\": [3, 3, 2, 6]}, {\"data\": [-0.665749, -0.838004, 0.920451, 0.8769, 0.241718, -0.148796], \"shape\": [6]}], \"expected\": {\"data\": [-1.275021, -0.839414, 2.261197, 1.382413, -1.349136, -0.458723, 0.035198, 0.059575, 3.289305, -0.112961, 1.894473, 0.056991, 1.030516, -1.41503, 3.639513, 1.312105, 
0.7178, 1.037027, -0.345448, -1.125389, 2.561248, 1.919182, 2.434602, -0.628632, -1.548767, -0.63035, 1.276973, 2.352928, 0.291083, -0.290137, -0.282871, -1.546028, 1.196641, 0.191184, -1.299763, -0.378339, -1.150905, -2.640953, 1.198463, 1.328465, 0.673755, 0.544141, -0.052356, -0.622702, 1.430426, 1.480166, 0.264735, 0.941609, -0.76932, -0.524748, -0.097613, -0.037621, 0.036755, 0.66339, -1.016433, -0.324838, -1.200336, 0.691055, 0.364328, -1.434147, -0.663102, -1.222518, 2.413622, 0.589465, 2.182118, 0.722914, 1.059836, -2.253524, 1.283919, 2.72919, -1.923285, 3.112732, 0.110403, -2.895395, -0.192991, 2.223259, 0.511342, 0.309826, -1.961463, -0.378234, -1.644435, 0.665466, -0.049178, -2.035013, -0.420975, -2.059799, 0.095355, -0.333549, -1.101064, 0.390433, -1.458754, -2.171488, -0.077547, 0.216551, -0.499852, -0.68166, -0.994649, -1.396949, 0.72207, 1.646976, -1.717063, 1.138718, 0.033565, -1.549827, 1.732043, 2.246031, 0.799031, 1.133765, -1.329821, -0.199264, 1.268272, 3.943292, -0.412988, 1.989369, 0.331454, -1.667789, 1.287873, 1.011695, -0.393542, 0.922799], \"shape\": [5, 2, 2, 6]}}}\n"
     ]
    }
   ],
   "source": [
    "# Echo the serialized fixtures so the generated JSON can be inspected inline.\n",
    "print(json.dumps(DATA))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
